# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmcv.ops import grouping_operation
from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'npu',
        marks=pytest.mark.skipif(
            not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
def test_grouping_points(dtype, device):
    """Check forward and backward of ``grouping_operation`` (batched call).

    ``features`` is a ``(2, 3, 10)`` tensor and ``idx`` a ``(2, 6, 3)`` int
    tensor; the op is expected to return a ``(2, 3, 6, 3)`` tensor whose
    entries are the values of ``features`` selected by ``idx`` along the last
    axis. The result is then back-propagated with itself as the upstream
    gradient and ``features.grad`` is compared with precomputed values.

    Args:
        dtype (torch.dtype): Floating point type of ``features``.
        device (str): Device the tensors are moved to, ``'cuda'`` or ``'npu'``.
    """
    idx = torch.tensor([[[0, 0, 0], [3, 3, 3], [8, 8, 8], [0, 0, 0], [0, 0, 0],
                         [0, 0, 0]],
                        [[0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0], [0, 0, 0],
                         [0, 0, 0]]]).int().to(device)
    features = torch.tensor(
        [
            [
                [
                    0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277,
                    -0.4164, -1.8274, 0.9268, 0.8414
                ],
                [
                    5.4247, 1.5113, 2.3944, 1.4740, 5.0300, 5.1030, 1.9360,
                    2.1939, 2.1581, 3.4666
                ],
                [
                    -1.6266, -1.0281, -1.0393, -1.6931, -1.3982, -0.5732,
                    -1.0830, -1.7561, -1.6786, -1.6967
                ],
            ],
            [
                [
                    -0.0380, -0.1880, -1.5724, 0.6905, -0.3190, 0.7798,
                    -0.3693, -0.9457, -0.2942, -1.8527
                ],
                [
                    1.1773, 1.5009, 2.6399, 5.9242, 1.0962, 2.7346, 6.0865,
                    1.5555, 4.3303, 2.8229
                ],
                [
                    -0.6646, -0.6870, -0.1125, -0.2224, -0.3445, -1.4049,
                    0.4990, -0.7037, -0.9924, 0.0386
                ],
            ],
        ],
        dtype=dtype).to(device)
    # Set on the tensor that already lives on ``device`` so it stays a leaf
    # and ``features.grad`` is populated by ``backward``.
    features.requires_grad = True

    output = grouping_operation(features, idx)
    # Use the forward result itself as the upstream gradient.
    output.backward(output)
    grad_features = features.grad

    expected_output = torch.tensor(
        [[[[0.5798, 0.5798, 0.5798], [-1.3311, -1.3311, -1.3311],
           [0.9268, 0.9268, 0.9268], [0.5798, 0.5798, 0.5798],
           [0.5798, 0.5798, 0.5798], [0.5798, 0.5798, 0.5798]],
          [[5.4247, 5.4247, 5.4247], [1.4740, 1.4740, 1.4740],
           [2.1581, 2.1581, 2.1581], [5.4247, 5.4247, 5.4247],
           [5.4247, 5.4247, 5.4247], [5.4247, 5.4247, 5.4247]],
          [[-1.6266, -1.6266, -1.6266], [-1.6931, -1.6931, -1.6931],
           [-1.6786, -1.6786, -1.6786], [-1.6266, -1.6266, -1.6266],
           [-1.6266, -1.6266, -1.6266], [-1.6266, -1.6266, -1.6266]]],
         [[[-0.0380, -0.0380, -0.0380], [-0.3693, -0.3693, -0.3693],
           [-1.8527, -1.8527, -1.8527], [-0.0380, -0.0380, -0.0380],
           [-0.0380, -0.0380, -0.0380], [-0.0380, -0.0380, -0.0380]],
          [[1.1773, 1.1773, 1.1773], [6.0865, 6.0865, 6.0865],
           [2.8229, 2.8229, 2.8229], [1.1773, 1.1773, 1.1773],
           [1.1773, 1.1773, 1.1773], [1.1773, 1.1773, 1.1773]],
          [[-0.6646, -0.6646, -0.6646], [0.4990, 0.4990, 0.4990],
           [0.0386, 0.0386, 0.0386], [-0.6646, -0.6646, -0.6646],
           [-0.6646, -0.6646, -0.6646], [-0.6646, -0.6646, -0.6646]]]],
        dtype=dtype).to(device)
    # Every non-zero entry is (number of times the index occurs in ``idx``)
    # times the feature value, e.g. index 0 occurs 12 times in each batch, so
    # 6.9576 == 12 * 0.5798; index 3 occurs 3 times, so -3.9933 == 3 * -1.3311.
    expected_grad_features = torch.tensor(
        [
            [
                [
                    6.9576, 0.0000, 0.0000, -3.9933, 0.0000, 0.0000, 0.0000,
                    0.0000, 2.7804, 0.0000
                ],
                [
                    65.0964, 0.0000, 0.0000, 4.4220, 0.0000, 0.0000, 0.0000,
                    0.0000, 6.4743, 0.0000
                ],
                [
                    -19.5192, 0.0000, 0.0000, -5.0793, 0.0000, 0.0000, 0.0000,
                    0.0000, -5.0358, 0.0000
                ],
            ],
            [
                [
                    -0.4560, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, -1.1079,
                    0.0000, 0.0000, -5.5581
                ],
                [
                    14.1276, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 18.2595,
                    0.0000, 0.0000, 8.4687
                ],
                [
                    -7.9752, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.4970,
                    0.0000, 0.0000, 0.1158
                ],
            ],
        ],
        dtype=dtype).to(device)

    # The expected forward values are verbatim entries of ``features`` cast to
    # the same dtype, so the default tolerances are sufficient here.
    assert torch.allclose(output, expected_output)

    # The gradient is a sum of up to 12 contributions per element. In half
    # precision that sum picks up rounding error of a few ulps (relative error
    # around 1e-3), which the default rtol=1e-5 / atol=1e-8 would reject, so
    # the tolerance is relaxed for ``torch.half`` only.
    grad_tol = dict(rtol=1e-2, atol=1e-2) if dtype == torch.half else {}
    assert torch.allclose(grad_features, expected_grad_features, **grad_tol)


@pytest.mark.parametrize('device', [
    pytest.param(
        'cuda',
        marks=pytest.mark.skipif(
            not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
    pytest.param(
        'npu',
        marks=pytest.mark.skipif(
            not IS_NPU_AVAILABLE, reason='requires NPU support'))
])
@pytest.mark.parametrize('dtype', [torch.half, torch.float, torch.double])
def test_stack_grouping_points(dtype, device):
    """Check the forward pass of ``grouping_operation`` (stacked call).

    ``features`` is a ``(6, 10)`` tensor and ``idx`` a ``(12, 3)`` int tensor;
    they are passed together with ``features_batch_cnt == [3, 3]`` and
    ``indices_batch_cnt == [6, 6]``. The result is compared with a precomputed
    ``(12, 10, 3)`` tensor. Only the forward output is verified.

    Args:
        dtype (torch.dtype): Floating point type of ``features``.
        device (str): Device the tensors are moved to, ``'cuda'`` or ``'npu'``.
    """
    if device == 'npu' and dtype == torch.double:
        # This combination is deliberately not exercised (presumably double
        # is unsupported by the NPU kernel -- TODO confirm). Report it as
        # skipped instead of returning, which would be counted as a pass.
        pytest.skip('torch.double is not tested on NPU for stack grouping')

    idx = torch.tensor([[0, 0, 0], [3, 3, 3], [8, 8, 8], [1, 1, 1], [0, 0, 0],
                        [2, 2, 2], [0, 0, 0], [6, 6, 6], [9, 9, 9], [0, 0, 0],
                        [1, 1, 1], [0, 0, 0]]).int().to(device)
    features = torch.tensor(
        [
            [
                0.5798, -0.7981, -0.9280, -1.3311, 1.3687, 0.9277, -0.4164,
                -1.8274, 0.9268, 0.8414
            ],
            [
                5.4247, 1.5113, 2.3944, 1.4740, 5.0300, 5.1030, 1.9360,
                2.1939, 2.1581, 3.4666
            ],
            [
                -1.6266, -1.0281, -1.0393, -1.6931, -1.3982, -0.5732, -1.0830,
                -1.7561, -1.6786, -1.6967
            ],
            [
                -0.0380, -0.1880, -1.5724, 0.6905, -0.3190, 0.7798, -0.3693,
                -0.9457, -0.2942, -1.8527
            ],
            [
                1.1773, 1.5009, 2.6399, 5.9242, 1.0962, 2.7346, 6.0865,
                1.5555, 4.3303, 2.8229
            ],
            [
                -0.6646, -0.6870, -0.1125, -0.2224, -0.3445, -1.4049, 0.4990,
                -0.7037, -0.9924, 0.0386
            ],
        ],
        dtype=dtype).to(device)
    # Per-batch counts for the stacked layout: the 6 feature rows and the 12
    # index rows are each declared as two batches of equal size.
    features_batch_cnt = torch.tensor([3, 3]).int().to(device)
    indices_batch_cnt = torch.tensor([6, 6]).int().to(device)

    output = grouping_operation(features, idx, features_batch_cnt,
                                indices_batch_cnt)

    expected_output = torch.tensor(
        [[[0.5798, 0.5798, 0.5798], [-0.7981, -0.7981, -0.7981],
          [-0.9280, -0.9280, -0.9280], [-1.3311, -1.3311, -1.3311],
          [1.3687, 1.3687, 1.3687], [0.9277, 0.9277, 0.9277],
          [-0.4164, -0.4164, -0.4164], [-1.8274, -1.8274, -1.8274],
          [0.9268, 0.9268, 0.9268], [0.8414, 0.8414, 0.8414]],
         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
         [[5.4247, 5.4247, 5.4247], [1.5113, 1.5113, 1.5113],
          [2.3944, 2.3944, 2.3944], [1.4740, 1.4740, 1.4740],
          [5.0300, 5.0300, 5.0300], [5.1030, 5.1030, 5.1030],
          [1.9360, 1.9360, 1.9360], [2.1939, 2.1939, 2.1939],
          [2.1581, 2.1581, 2.1581], [3.4666, 3.4666, 3.4666]],
         [[0.5798, 0.5798, 0.5798], [-0.7981, -0.7981, -0.7981],
          [-0.9280, -0.9280, -0.9280], [-1.3311, -1.3311, -1.3311],
          [1.3687, 1.3687, 1.3687], [0.9277, 0.9277, 0.9277],
          [-0.4164, -0.4164, -0.4164], [-1.8274, -1.8274, -1.8274],
          [0.9268, 0.9268, 0.9268], [0.8414, 0.8414, 0.8414]],
         [[-1.6266, -1.6266, -1.6266], [-1.0281, -1.0281, -1.0281],
          [-1.0393, -1.0393, -1.0393], [-1.6931, -1.6931, -1.6931],
          [-1.3982, -1.3982, -1.3982], [-0.5732, -0.5732, -0.5732],
          [-1.0830, -1.0830, -1.0830], [-1.7561, -1.7561, -1.7561],
          [-1.6786, -1.6786, -1.6786], [-1.6967, -1.6967, -1.6967]],
         [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880],
          [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905],
          [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
          [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
          [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]],
         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
         [[0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000], [0.0000, 0.0000, 0.0000]],
         [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880],
          [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905],
          [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
          [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
          [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]],
         [[1.1773, 1.1773, 1.1773], [1.5009, 1.5009, 1.5009],
          [2.6399, 2.6399, 2.6399], [5.9242, 5.9242, 5.9242],
          [1.0962, 1.0962, 1.0962], [2.7346, 2.7346, 2.7346],
          [6.0865, 6.0865, 6.0865], [1.5555, 1.5555, 1.5555],
          [4.3303, 4.3303, 4.3303], [2.8229, 2.8229, 2.8229]],
         [[-0.0380, -0.0380, -0.0380], [-0.1880, -0.1880, -0.1880],
          [-1.5724, -1.5724, -1.5724], [0.6905, 0.6905, 0.6905],
          [-0.3190, -0.3190, -0.3190], [0.7798, 0.7798, 0.7798],
          [-0.3693, -0.3693, -0.3693], [-0.9457, -0.9457, -0.9457],
          [-0.2942, -0.2942, -0.2942], [-1.8527, -1.8527, -1.8527]]],
        dtype=dtype).to(device)
    # Expected values are either zero or verbatim entries of ``features`` in
    # the same dtype, so the default ``allclose`` tolerances are adequate.
    assert torch.allclose(output, expected_output)
